{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use case proposed in [issue #51](https://github.com/bsorrentino/langgraph4j/issues/51) by [pakamona](https://github.com/pakamona)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Build coordinates consumed by the dependency magics in the next cell.\n",
    "String userHomeDir = System.getProperty(\"user.home\");\n",
    "// URL of the local Maven repository (~/.m2/repository), registered as repo 'local' in the next cell.\n",
    "String localRespoUrl = String.format(\"file://%s/.m2/repository/\", userHomeDir);\n",
    "\n",
    "// Library versions.\n",
    "String langchain4jVersion = \"1.0.1\";\n",
    "String langchain4jbeta = \"1.0.1-beta6\";\n",
    "String langgraph4jVersion = \"1.6-SNAPSHOT\";"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Register the local Maven repository under the name 'local'.\n",
    "%dependency /add-repo local \\{localRespoUrl} release|never snapshot|always\n",
    "// %dependency /list-repos\n",
    "// SLF4J binding for java.util.logging (configured in the logger cell below).\n",
    "%dependency /add org.slf4j:slf4j-jdk14:2.0.9\n",
    "// langgraph4j artifacts, all at the version set in the first cell.\n",
    "%dependency /add org.bsc.langgraph4j:langgraph4j-core:\\{langgraph4jVersion}\n",
    "%dependency /add org.bsc.langgraph4j:langgraph4j-langchain4j:\\{langgraph4jVersion}\n",
    "%dependency /add org.bsc.langgraph4j:langgraph4j-agent-executor:\\{langgraph4jVersion}\n",
    "// langchain4j core and its OpenAI integration.\n",
    "%dependency /add dev.langchain4j:langchain4j:\\{langchain4jVersion}\n",
    "%dependency /add dev.langchain4j:langchain4j-open-ai:\\{langchain4jVersion}\n",
    "\n",
    "%dependency /resolve"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initialize Logger**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Load the java.util.logging configuration from ./logging.properties (relative to the working directory).\n",
    "try (java.io.InputStream loggingConfig = new java.io.FileInputStream(\"./logging.properties\")) {\n",
    "    var logManager = java.util.logging.LogManager.getLogManager();\n",
    "    logManager.readConfiguration(loggingConfig);\n",
    "}\n",
    "\n",
    "// SLF4J logger used by the agent and run cells below.\n",
    "org.slf4j.Logger log = org.slf4j.LoggerFactory.getLogger(\"AdaptiveRag\");\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.bsc.langgraph4j.state.AgentState;\n",
    "import java.util.Optional;\n",
    "\n",
    "/**\n",
    " * Graph state for this notebook: typed accessors over the values held by AgentState.\n",
    " */\n",
    "public class MyAgentState extends AgentState {\n",
    "\n",
    "    /** State key under which the user request is stored. */\n",
    "    private static final String INPUT_KEY = \"input\";\n",
    "\n",
    "    /** State key under which the orchestrator outcome is stored. */\n",
    "    private static final String ORCHESTRATOR_OUTCOME_KEY = \"orchestrator_outcome\";\n",
    "\n",
    "    /**\n",
    "     * @param initData initial state values, handed to the AgentState constructor\n",
    "     */\n",
    "    public MyAgentState(Map<String,Object> initData) {\n",
    "        super(initData);\n",
    "    }\n",
    "\n",
    "    /** @return the value stored under the input key, if any */\n",
    "    Optional<String> input() {\n",
    "        return value( INPUT_KEY );\n",
    "    }\n",
    "\n",
    "    /** @return the value stored under the orchestrator outcome key, if any */\n",
    "    Optional<String> orchestratorOutcome() {\n",
    "        return value( ORCHESTRATOR_OUTCOME_KEY );\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.bsc.langgraph4j.state.AgentState;\n",
    "import org.bsc.langgraph4j.action.NodeAction;\n",
    "\n",
    "import java.util.Map;\n",
    "import java.util.ArrayList;\n",
    "import java.util.concurrent.CompletableFuture;\n",
    "import java.util.function.Function;\n",
    "\n",
    "import dev.langchain4j.data.message.ChatMessage;\n",
    "import dev.langchain4j.data.message.SystemMessage;\n",
    "import dev.langchain4j.data.message.UserMessage;\n",
    "\n",
    "import dev.langchain4j.model.chat.ChatModel;\n",
    "import dev.langchain4j.model.input.PromptTemplate;\n",
    "import dev.langchain4j.model.output.Response;\n",
    "import dev.langchain4j.model.chat.request.ChatRequest;\n",
    "\n",
    "/**\n",
    " * Orchestrator node: sends the user request to the chat model together with a\n",
    " * classification system prompt and returns the model's reply as the\n",
    " * \"orchestrator_outcome\" state value.\n",
    " */\n",
    "class OrchestratorAgent implements NodeAction<MyAgentState> {\n",
    "\n",
    "    private final ChatModel chatModel;\n",
    "\n",
    "    /**\n",
    "     * @param chatModel chat model used to classify the user request\n",
    "     */\n",
    "    public OrchestratorAgent( ChatModel chatModel ) {\n",
    "        this.chatModel = chatModel;\n",
    "    }\n",
    "\n",
    "    /**\n",
    "     * Classifies the request found in the state.\n",
    "     *\n",
    "     * @param state current graph state; must contain an \"input\" value\n",
    "     * @return a map with the model's reply text under \"orchestrator_outcome\"\n",
    "     * @throws IllegalArgumentException if the state has no \"input\" value\n",
    "     */\n",
    "    public Map<String, Object> apply(MyAgentState state) throws Exception {\n",
    "\n",
    "        var input = state.input().orElseThrow( () -> new IllegalArgumentException(\"input is not provided!\"));\n",
    "\n",
    "        var userPrompt = PromptTemplate.from(\"{{input}}\").apply(Map.of(\"input\", input));\n",
    "\n",
    "        var messages = new ArrayList<ChatMessage>();\n",
    "\n",
    "        messages.add(new SystemMessage(\"\"\"\n",
    "        You are a helpful assistant. Evaluate the user request and if the request concerns a story return 'story_teller' otherwise 'greeting'\n",
    "        \"\"\"));\n",
    "        messages.add(new UserMessage(userPrompt.text()));\n",
    "\n",
    "        var request = ChatRequest.builder()\n",
    "                .messages( messages )\n",
    "                .build();\n",
    "\n",
    "        // Call the injected model, not a notebook-level variable, so the agent\n",
    "        // uses the model it was constructed with and does not depend on cell order.\n",
    "        var result = chatModel.chat( request );\n",
    "\n",
    "        return Map.of( \"orchestrator_outcome\", result.aiMessage().text() );\n",
    "    }\n",
    "\n",
    "}\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.bsc.langgraph4j.action.EdgeAction;\n",
    "\n",
    "/**\n",
    " * Edge action that routes on the orchestrator outcome stored in the state.\n",
    " */\n",
    "class RouteOrchestratorOutcome implements EdgeAction<MyAgentState> {\n",
    "\n",
    "    /**\n",
    "     * @param state current graph state\n",
    "     * @return the orchestrator outcome exactly as stored in the state\n",
    "     * @throws IllegalArgumentException if the state holds no orchestrator outcome\n",
    "     */\n",
    "    public String apply(MyAgentState state) throws Exception {\n",
    "        var outcome = state.orchestratorOutcome();\n",
    "        if( outcome.isEmpty() ) {\n",
    "            throw new IllegalArgumentException(\"orchestration outcome is not provided!\");\n",
    "        }\n",
    "        return outcome.get();\n",
    "    }\n",
    "\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "/**\n",
    " * Story-teller node: logs that it was invoked and returns an empty map.\n",
    " */\n",
    "class StoryTellerAgent implements NodeAction<MyAgentState> {\n",
    "\n",
    "    @Override\n",
    "    public Map<String, Object> apply(MyAgentState state) throws Exception {\n",
    "        log.info( \"Story Teller Agent invoked\");\n",
    "\n",
    "        final Map<String, Object> noUpdates = Map.of();\n",
    "        return noUpdates;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "/**\n",
    " * Greeting node: logs that it was invoked and returns an empty map.\n",
    " */\n",
    "class GreetingAgent implements NodeAction<MyAgentState> {\n",
    "\n",
    "    @Override\n",
    "    public Map<String, Object> apply(MyAgentState state) throws Exception {\n",
    "        log.info( \"Greeting Agent invoked\");\n",
    "\n",
    "        final Map<String, Object> noUpdates = Map.of();\n",
    "        return noUpdates;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import static org.bsc.langgraph4j.action.AsyncNodeAction.node_async;\n",
    "import static org.bsc.langgraph4j.action.AsyncEdgeAction.edge_async;\n",
    "import org.bsc.langgraph4j.StateGraph;\n",
    "import static org.bsc.langgraph4j.StateGraph.START;\n",
    "import static org.bsc.langgraph4j.StateGraph.END;\n",
    "import dev.langchain4j.model.openai.OpenAiChatModel;\n",
    "\n",
    "// Chat model handed to the orchestrator; the API key is read from the OPENAI_API_KEY environment variable.\n",
    "OpenAiChatModel model = OpenAiChatModel.builder()\n",
    "        .apiKey( System.getenv(\"OPENAI_API_KEY\") )\n",
    "        .modelName( \"gpt-4o-mini\" )\n",
    "        .temperature(0.0)\n",
    "        .maxTokens(2000)\n",
    "        .maxRetries(2)\n",
    "        .logResponses(true)\n",
    "        .build();\n",
    "\n",
    "// Wrap the node and edge actions with node_async / edge_async before adding them to the graph.\n",
    "var orchestratorAgent = node_async( new OrchestratorAgent(model) );\n",
    "var storyTellerAgent = node_async( new StoryTellerAgent() );\n",
    "var greetingAgent = node_async( new GreetingAgent() );\n",
    "var routeOrchestratorOutcome = edge_async( new RouteOrchestratorOutcome() );\n",
    "\n",
    "// START -> orchestrator_agent -> (story_teller_agent | greetings_agent) -> END\n",
    "var workflow = new StateGraph<>( MyAgentState::new )\n",
    "        .addNode( \"orchestrator_agent\", orchestratorAgent )\n",
    "        .addNode( \"story_teller_agent\", storyTellerAgent )\n",
    "        .addNode( \"greetings_agent\", greetingAgent )\n",
    "        .addEdge( START, \"orchestrator_agent\" )\n",
    "        .addConditionalEdges( \"orchestrator_agent\",\n",
    "                routeOrchestratorOutcome,\n",
    "                Map.of(\n",
    "                    \"story_teller\", \"story_teller_agent\",\n",
    "                    \"greeting\", \"greetings_agent\" ) )\n",
    "        .addEdge( \"story_teller_agent\", END )\n",
    "        .addEdge( \"greetings_agent\", END );\n",
    "\n",
    "var app = workflow.compile();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "START \n",
      "NodeOutput{node=__START__, state={input=tell me a xmas story}} \n",
      "Story Teller Agent invoked \n",
      "NodeOutput{node=orchestrator_agent, state={input=tell me a xmas story, orchestrator_outcome=story_teller}} \n",
      "NodeOutput{node=story_teller_agent, state={input=tell me a xmas story, orchestrator_outcome=story_teller}} \n",
      "NodeOutput{node=__END__, state={input=tell me a xmas story, orchestrator_outcome=story_teller}} \n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "// Run the graph on a story request and log every element the stream yields.\n",
    "Map<String, Object> storyRequest = Map.of( \"input\", \"tell me a xmas story\" );\n",
    "for( var step : app.stream( storyRequest ) ) {\n",
    "    log.info( \"{}\", step );\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "START \n",
      "NodeOutput{node=__START__, state={input=hi there}} \n",
      "Greeting Agent invoked \n",
      "NodeOutput{node=orchestrator_agent, state={input=hi there, orchestrator_outcome=greeting}} \n",
      "NodeOutput{node=greetings_agent, state={input=hi there, orchestrator_outcome=greeting}} \n",
      "NodeOutput{node=__END__, state={input=hi there, orchestrator_outcome=greeting}} \n"
     ]
    }
   ],
   "source": [
    "// Run the graph on a greeting request and log every element the stream yields.\n",
    "Map<String, Object> greetingRequest = Map.of( \"input\", \"hi there\" );\n",
    "for( var step : app.stream( greetingRequest ) ) {\n",
    "    log.info( \"{}\", step );\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java (rjk 2.2.0)",
   "language": "java",
   "name": "rapaio-jupyter-kernel"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "java",
   "nbconvert_exporter": "script",
   "pygments_lexer": "java",
   "version": "22.0.2+9-70"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
